{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47ce31dc-0635-4049-821a-b3fd6102ae7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import jieba\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the dataset directly over the network; the files can also be downloaded first and read locally.\n",
    "data_dir = 'https://mirror.coggle.club/dataset/coggle-competition/'\n",
    "\n",
    "# Both splits are tab-separated with no header row, so pandas labels the columns 0, 1, ...\n",
    "train_data = pd.read_csv(\n",
    "    data_dir + 'intent-classify/train.csv',\n",
    "    header=None,\n",
    "    sep='\\t',\n",
    ")\n",
    "test_data = pd.read_csv(\n",
    "    data_dir + 'intent-classify/test.csv',\n",
    "    header=None,\n",
    "    sep='\\t',\n",
    ")\n",
    "\n",
    "# Stop-word list: read the headerless file and keep its first column as an array.\n",
    "cn_stopwords = pd.read_csv(\n",
    "    'https://mirror.coggle.club/stopwords/baidu_stopwords.txt',\n",
    "    header=None,\n",
    ")[0].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a08e8b45-9256-42d3-b89e-bb76dc94d9c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/feature_extraction/text.py:525: UserWarning: The parameter 'token_pattern' will not be used since 'tokenizer' is not None'\n",
      "  warnings.warn(\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.626 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/feature_extraction/text.py:408: UserWarning: Your stop_words may be inconsistent with your preprocessing. Tokenizing the stop words generated tokens [\"'\", 'a', 'ain', 'aren', 'c', 'couldn', 'd', 'didn', 'doesn', 'don', 'hadn', 'hasn', 'haven', 'i', 'isn', 'll', 'm', 'mon', 's', 'shouldn', 't', 've', 'wasn', 'weren', 'won', 'wouldn', '下', '不', '为什', '什', '今', '使', '先', '却', '只', '唷', '啪', '喔', '天', '好', '後', '最', '漫', '然', '特', '特别', '见', '设', '说', '达', '面', '麽', '－'] not in stop_words.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "\n",
    "# Build TF-IDF features, using jieba word segmentation as the tokenizer.\n",
    "tfidf = TfidfVectorizer(\n",
    "    tokenizer=jieba.lcut,\n",
    "    # With a custom tokenizer the default token_pattern is never used; passing None says so\n",
    "    # explicitly and avoids scikit-learn's \"token_pattern will not be used\" UserWarning.\n",
    "    token_pattern=None,\n",
    "    stop_words=list(cn_stopwords)\n",
    ")\n",
    "# NOTE(review): scikit-learn also warns that tokenizing the stop words yields tokens that are\n",
    "# not in the list. Left as-is here, because extending the list would change the features.\n",
    "\n",
    "# Fit the vocabulary and IDF weights on the training text (column 0) only, then apply the\n",
    "# fitted vectorizer to the test text.\n",
    "train_tfidf = tfidf.fit_transform(train_data[0])\n",
    "test_tfidf = tfidf.transform(test_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cfcc582a-80ae-458f-bedf-2233a15eeca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.svm import LinearSVC\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.model_selection import cross_val_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e710df9c-7743-45d4-8b40-9b7bec898121",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                       precision    recall  f1-score   support\n",
      "\n",
      "         Alarm-Update       0.98      0.93      0.96      1264\n",
      "           Audio-Play       0.74      0.50      0.60       226\n",
      "       Calendar-Query       0.99      0.95      0.97      1214\n",
      "        FilmTele-Play       0.70      0.93      0.80      1355\n",
      "HomeAppliance-Control       0.94      0.97      0.96      1215\n",
      "           Music-Play       0.88      0.87      0.87      1304\n",
      "                Other       0.39      0.07      0.11       214\n",
      "         Radio-Listen       0.94      0.89      0.91      1285\n",
      "       TVProgram-Play       0.72      0.45      0.55       240\n",
      "         Travel-Query       0.92      0.96      0.94      1220\n",
      "           Video-Play       0.90      0.87      0.89      1334\n",
      "        Weather-Query       0.92      0.96      0.94      1229\n",
      "\n",
      "             accuracy                           0.89     12100\n",
      "            macro avg       0.84      0.78      0.79     12100\n",
      "         weighted avg       0.89      0.89      0.89     12100\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Cross-validated predictions for logistic regression (default settings and default splitting).\n",
    "cv_pred = cross_val_predict(LogisticRegression(), X=train_tfidf, y=train_data[1])\n",
    "\n",
    "# Per-class precision / recall / F1 of those predictions against the labels in column 1.\n",
    "print(classification_report(y_true=train_data[1], y_pred=cv_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7f89dd4e-1374-419c-9ac1-df2435547997",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                       precision    recall  f1-score   support\n",
      "\n",
      "         Alarm-Update       0.97      0.95      0.96      1264\n",
      "           Audio-Play       0.64      0.71      0.67       226\n",
      "       Calendar-Query       0.98      0.97      0.98      1214\n",
      "        FilmTele-Play       0.81      0.89      0.85      1355\n",
      "HomeAppliance-Control       0.97      0.98      0.98      1215\n",
      "           Music-Play       0.90      0.89      0.89      1304\n",
      "                Other       0.31      0.25      0.27       214\n",
      "         Radio-Listen       0.94      0.90      0.92      1285\n",
      "       TVProgram-Play       0.66      0.62      0.64       240\n",
      "         Travel-Query       0.95      0.98      0.97      1220\n",
      "           Video-Play       0.92      0.88      0.90      1334\n",
      "        Weather-Query       0.96      0.97      0.97      1229\n",
      "\n",
      "             accuracy                           0.91     12100\n",
      "            macro avg       0.83      0.83      0.83     12100\n",
      "         weighted avg       0.91      0.91      0.91     12100\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Cross-validated predictions for a linear SVM, using cross_val_predict's default splitting.\n",
    "cv_pred = cross_val_predict(\n",
    "    # Pin `dual` explicitly: True is the default in the scikit-learn version that produced this\n",
    "    # notebook's output, and that default is announced to change to 'auto' in 1.5. Setting it\n",
    "    # keeps the results stable across versions and silences the per-fold FutureWarning.\n",
    "    LinearSVC(dual=True),\n",
    "    train_tfidf, train_data[1]\n",
    ")\n",
    "# Per-class precision / recall / F1 of those predictions against the labels in column 1.\n",
    "print(classification_report(train_data[1], cv_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7cff5d84-826e-484b-a555-6da50bda667b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                       precision    recall  f1-score   support\n",
      "\n",
      "         Alarm-Update       0.83      0.92      0.87      1264\n",
      "           Audio-Play       0.55      0.63      0.59       226\n",
      "       Calendar-Query       0.81      0.96      0.88      1214\n",
      "        FilmTele-Play       0.80      0.79      0.80      1355\n",
      "HomeAppliance-Control       0.92      0.97      0.94      1215\n",
      "           Music-Play       0.83      0.82      0.83      1304\n",
      "                Other       0.20      0.25      0.22       214\n",
      "         Radio-Listen       0.91      0.83      0.87      1285\n",
      "       TVProgram-Play       0.55      0.39      0.46       240\n",
      "         Travel-Query       0.94      0.90      0.92      1220\n",
      "           Video-Play       0.87      0.73      0.79      1334\n",
      "        Weather-Query       0.92      0.89      0.90      1229\n",
      "\n",
      "             accuracy                           0.84     12100\n",
      "            macro avg       0.76      0.76      0.76     12100\n",
      "         weighted avg       0.84      0.84      0.84     12100\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Cross-validated predictions for k-nearest neighbours (default settings and default splitting).\n",
    "cv_pred = cross_val_predict(KNeighborsClassifier(), X=train_tfidf, y=train_data[1])\n",
    "\n",
    "# Per-class precision / recall / F1 of those predictions against the labels in column 1.\n",
    "print(classification_report(y_true=train_data[1], y_pred=cv_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "06fe4b6a-e4bb-4116-b6c9-1f22a37c368a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lyz/anaconda3/envs/py311/lib/python3.11/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Fit the final model on the full training set. `dual=True` is set explicitly to keep the\n",
    "# current solver choice and to avoid the FutureWarning about the default changing in 1.5.\n",
    "model = LinearSVC(dual=True)\n",
    "model.fit(train_tfidf, train_data[1])\n",
    "\n",
    "# Write the submission file: 1-based row IDs paired with the predicted labels.\n",
    "pd.DataFrame({\n",
    "    'ID':range(1, len(test_data) + 1),\n",
    "    \"Target\":model.predict(test_tfidf)\n",
    "}).to_csv('nlp_submit.csv', index=False)  # index=False: do not write the DataFrame index\n",
    "# The file can then be submitted (the original note does not say where)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1787ac-fdc5-461f-92bc-d6361cab525d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
